# import argparse
import sys
from pathlib import Path
from pytorch_lightning.cli import LightningCLI
from PIL import Image

# For streaming
import yaml
from copy import deepcopy
from typing import List, Optional
from jsonargparse.typing import restricted_string_type


# --------------------------------------
# ----------- For Streaming ------------
# --------------------------------------
class CustomCLI(LightningCLI):
    """LightningCLI subclass that registers this project's extra command-line options."""

    def add_arguments_to_parser(self, parser):
        """Add the project-specific options to ``parser``.

        Registered options: ``--result_fol``, ``--exp_name``, ``--run_name``,
        ``--prompts``, ``--scale_lr``, ``--matmul_precision``, ``--ckpt`` and
        ``--n_predictions``.

        Args:
            parser: The argument parser handed over by ``LightningCLI``.

        Returns:
            The same ``parser`` instance, with the extra arguments added.
        """
        parser.add_argument("--result_fol", type=Path,
                            help="Set the path to the result folder", default="results")
        parser.add_argument("--exp_name", type=str, help="Experiment name")
        parser.add_argument("--run_name", type=str,
                            help="Current run name")
        parser.add_argument("--prompts", type=Optional[List[str]])
        parser.add_argument("--scale_lr", type=bool,
                            help="Scale lr", default=False)
        # The pattern is anchored on both ends so that only the three exact
        # words are accepted; an unanchored alternation would also let values
        # such as "highfoo" through a prefix-style regex match.
        CodeType = restricted_string_type(
            'CodeType', '^(medium|high|highest)$')
        parser.add_argument("--matmul_precision", type=CodeType)
        parser.add_argument("--ckpt", type=Path,)
        parser.add_argument("--n_predictions", type=int)
        return parser

def remove_value(dictionary, x):
    """Delete every entry whose key equals ``x`` from a nested dictionary.

    The dictionary is modified in place. Values that are themselves ``dict``
    instances are processed recursively; other containers (lists, tuples, ...)
    are not descended into.

    Args:
        dictionary: The mapping to clean up.
        x: The key to remove at every nesting level.

    Returns:
        The same ``dictionary`` object that was passed in.
    """
    # Work on a snapshot so entries can be deleted while looping.
    snapshot = tuple(dictionary.items())
    for name, child in snapshot:
        if name == x:
            del dictionary[name]
            continue
        if not isinstance(child, dict):
            continue
        remove_value(child, x)
    return dictionary

def legacy_transformation(cfg: dict) -> dict:
    """Return a copy of ``cfg`` adjusted for single-device use.

    The input is deep-copied and never modified. On the copy:

    * ``cfg["trainer"]["devices"]`` is set to the string ``"1"`` and
      ``cfg["trainer"]["num_nodes"]`` to ``1``.
    * If ``cfg["model"]["inference_params"]`` has no ``"class_path"`` entry,
      it is wrapped into a ``{"class_path": ..., "init_args": ...}`` mapping
      with the previous value as ``init_args``.

    Args:
        cfg: Parsed configuration mapping containing ``"trainer"`` and
            ``"model"`` sections.

    Returns:
        The transformed deep copy of ``cfg``.

    Raises:
        KeyError: If ``"trainer"``, ``"model"`` or
            ``cfg["model"]["inference_params"]`` is missing.
    """
    cfg = deepcopy(cfg)
    cfg["trainer"]["devices"] = "1"
    cfg["trainer"]['num_nodes'] = 1

    inference_params = cfg["model"]["inference_params"]
    if "class_path" not in inference_params:
        # Old-style configs store the parameters directly; wrap them into the
        # class_path/init_args layout.
        cfg["model"]["inference_params"] = {
            "class_path": "model.pl_module_params.InferenceParams",
            "init_args": inference_params,
        }
    return cfg


# ---------------------------------------------
# ----------- For enhancement -----------
# ---------------------------------------------
def add_margin(pil_img, top, right, bottom, left, color):
    """Return a new image with a border of ``color`` around ``pil_img``.

    Args:
        pil_img: Source image; its ``mode`` is reused for the result.
        top: Rows added above the image.
        right: Columns added to the right of the image.
        bottom: Rows added below the image.
        left: Columns added to the left of the image.
        color: Fill value for the new canvas.

    Returns:
        A new image of size ``(width + left + right, height + top + bottom)``
        with ``pil_img`` pasted at offset ``(left, top)``.
    """
    src_w, src_h = pil_img.size
    canvas_size = (src_w + right + left, src_h + top + bottom)
    canvas = Image.new(pil_img.mode, canvas_size, color)
    canvas.paste(pil_img, (left, top))
    return canvas

def resize_to_fit(image, size):
    """Resize ``image`` to fit inside ``size`` while keeping its aspect ratio.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method taking a ``(width, height)`` tuple.
        size: Target bounding box as ``(width, height)``.

    Returns:
        The result of ``image.resize`` with one side equal to the matching
        side of ``size`` and the other scaled proportionally (truncated to an
        integer).
    """
    box_w, box_h = size
    src_w, src_h = image.size
    # When the height could grow by a larger factor than the width, the
    # width is the limiting side.
    width_limited = box_h / src_h > box_w / src_w
    if width_limited:
        fitted = (box_w, int(src_h * box_w / src_w))
    else:
        fitted = (int(src_w * box_h / src_h), box_h)
    return image.resize(fitted)

def pad_to_fit(image, size):
    """Pad ``image`` so that the result has exactly the dimensions ``size``.

    The image is centred; when the size difference along an axis is odd, the
    extra pixel goes to the bottom / right side so the output still matches
    ``size`` exactly (using the same floored half on both sides would leave
    the result one pixel short).

    Args:
        image: Image exposing ``size`` as ``(width, height)``.
        size: Target dimensions as ``(width, height)``.

    Returns:
        The image returned by ``add_margin`` with the computed margins and a
        ``(0, 0, 0)`` fill colour.
    """
    W, H = size
    w, h = image.size
    pad_top = (H - h) // 2
    pad_left = (W - w) // 2
    # Give the remainder to the opposite side so the paddings sum to the full
    # difference even when it is odd.
    pad_bottom = (H - h) - pad_top
    pad_right = (W - w) - pad_left
    # NOTE(review): the (0, 0, 0) fill presumably assumes a 3-channel image.
    return add_margin(image, pad_top, pad_right, pad_bottom, pad_left, (0, 0, 0))

def resize_and_keep(pil_img):
    """Resize ``pil_img`` to a height of 576 pixels, keeping the aspect ratio.

    Args:
        pil_img: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method taking a ``(width, height)`` tuple.

    Returns:
        The result of ``pil_img.resize`` with height 576 and the width scaled
        by the same factor (truncated to an integer).
    """
    target_height = 576
    src_width, src_height = pil_img.size[0], pil_img.size[1]
    scale = target_height / float(src_height)
    target_width = int(float(src_width) * float(scale))
    return pil_img.resize((target_width, target_height))

def center_crop(pil_img):
    """Crop a 576x576 region from the centre of ``pil_img``.

    Args:
        pil_img: Object exposing ``size`` as ``(width, height)`` and a
            ``crop`` method taking a ``(left, top, right, bottom)`` box.

    Returns:
        The result of ``pil_img.crop`` for the centred box. The box
        coordinates are computed with true division and are therefore floats.
    """
    crop_w = 576
    crop_h = 576
    img_w, img_h = pil_img.size

    # Box centred on the image: (left, top, right, bottom).
    box = (
        (img_w - crop_w) / 2,
        (img_h - crop_h) / 2,
        (img_w + crop_w) / 2,
        (img_h + crop_h) / 2,
    )
    return pil_img.crop(box)